{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Implementation of Accurate Binary Convolution Layer\n",
    "[Original Paper](https://arxiv.org/abs/1711.11294)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "from __future__ import division, print_function\n",
    "import tensorflow as tf\n",
    "import numpy as np"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The motivation for this network is the use of deep neural networks for real-time object recognition. Standard **convolution layers** require a large amount of computation at runtime, which hinders the use of very deep networks on embedded systems or ASICs. Xiaofan Lin, Cong Zhao, and Wei Pan presented a way to convert convolution layers into **binary convolution layers** for faster real-time computation."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Approximating Convolution weights using binary weights\n",
    "The goal is to approximate the real-valued filters $\\mathbf{W}\\in\\mathbb{R}^{w\\times h\\times c_{in}\\times c_{out}}$ by the linear combination $\\alpha_1\\mathbf{B}_1+\\alpha_2\\mathbf{B}_2+...+\\alpha_M\\mathbf{B}_M$, where the binary filters $\\mathbf{B}_1, \\mathbf{B}_2, ..., \\mathbf{B}_M\\in\\{-1,+1\\}^{w\\times h\\times c_{in}\\times c_{out}}$ and the scalar weights $\\alpha_1, \\alpha_2, ..., \\alpha_M\\in\\mathbb{R}$"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "#### Conversion from convolution filter to binary filter\n",
    "Let's first implement the conversion of a convolution filter into binary convolution filters.\n",
    "To approximate $\\mathbf{W}$ with $\\alpha_1\\mathbf{B}_1+\\alpha_2\\mathbf{B}_2+...+\\alpha_M\\mathbf{B}_M$ we'll use the equation from the paper, $\\mathbf{B}_i=\\operatorname{sign}(\\bar{\\mathbf{W}} + \\mu_i\\operatorname{std}(\\mathbf{W}))$, where $\\bar{\\mathbf{W}}=\\mathbf{W}-\\operatorname{mean}(\\mathbf{W})$ is the mean-centered filter tensor and $\\mu_i$ is a shift parameter"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "We'll need mean and standard deviation of the complete convolution filters"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {
    "collapsed": true
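   },
   "outputs": [],
   "source": [
    "def get_mean_stddev(input_tensor):\n",
    "    # Mean and standard deviation taken over every element of the tensor\n",
    "    with tf.name_scope('mean_stddev_cal'):\n",
    "        # list(...) keeps this working when range() returns a lazy object (Python 3)\n",
    "        all_axes = list(range(len(input_tensor.get_shape())))\n",
    "        mean, variance = tf.nn.moments(input_tensor, axes=all_axes)\n",
    "        stddev = tf.sqrt(variance, name=\"standard_deviation\")\n",
    "        return mean, stddev"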
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Each binary filter uses a different shift $\\mu_i\\operatorname{std}(\\mathbf{W})$. As in the original paper, the $M$ shift parameters are spread evenly over $[-1, 1]$\n",
    "$\\mu_i= -1 + (i - 1)\\frac{2}{M - 1}, \\quad i = 1, ..., M$"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {
    "collapsed": true
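   },
   "outputs": [],
   "source": [
    "# TODO: Allow shift parameters to be learnable\n",
    "def get_shifted_stddev(stddev, no_filters):\n",
    "    # Returns mu_i * stddev for i = 1..no_filters (requires no_filters > 1)\n",
    "    with tf.name_scope('shifted_stddev'):\n",
    "        # list(...) keeps this working when range() returns a lazy object (Python 3)\n",
    "        indices = tf.convert_to_tensor(list(range(no_filters)), dtype=tf.float32)\n",
    "        spreaded_deviation = -1. + (2. / (no_filters - 1)) * indices\n",
    "        return spreaded_deviation * stddev"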
   ]
  },
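  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "As a quick sanity check of the shift formula, here is a worked instance, assuming $M = 5$ (the default value of `no_binary_filters` used further down in this notebook):\n",
    "\n",
    "$\\mu_1, \\mu_2, \\mu_3, \\mu_4, \\mu_5 = -1,\\ -0.5,\\ 0,\\ 0.5,\\ 1$\n",
    "\n",
    "Under that assumption the five binary filters are $\\operatorname{sign}(\\bar{\\mathbf{W}} - \\operatorname{std}(\\mathbf{W}))$, $\\operatorname{sign}(\\bar{\\mathbf{W}} - 0.5\\operatorname{std}(\\mathbf{W}))$, $\\operatorname{sign}(\\bar{\\mathbf{W}})$, $\\operatorname{sign}(\\bar{\\mathbf{W}} + 0.5\\operatorname{std}(\\mathbf{W}))$ and $\\operatorname{sign}(\\bar{\\mathbf{W}} + \\operatorname{std}(\\mathbf{W}))$. The code above obtains the same values from `range(no_filters)`, which starts at 0 and therefore plays the role of $i - 1$."
   ]
  },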
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Now we can compute the binary filters $\\mathbf{B}_i$"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "def get_binary_filters(convolution_filters, no_filters, name=None):\n",
    "    with tf.name_scope(name, default_name=\"get_binary_filters\"):\n",
    "        mean, stddev = get_mean_stddev(convolution_filters)\n",
    "        shifted_stddev = get_shifted_stddev(stddev, no_filters)\n",
    "        \n",
    "        # Center the filters by subtracting their mean\n",
    "        mean_adjusted_filters = convolution_filters - mean\n",
    "        \n",
    "        # Add a leading axis and tile it so there is one copy of the filters per binary filter\n",
    "        expanded_filters = tf.expand_dims(mean_adjusted_filters, axis=0, name=\"expanded_filters\")\n",
    "        tiled_filters = tf.tile(expanded_filters, [no_filters] + [1] * len(convolution_filters.get_shape()),\n",
    "                                name=\"tiled_filters\")\n",
    "        \n",
    "        # Reshape the shifted stddev to [no_filters, 1, ..., 1] so it broadcasts against tiled_filters\n",
    "        expanded_stddev = tf.reshape(shifted_stddev, [no_filters] + [1] * len(convolution_filters.get_shape()),\n",
    "                                     name=\"expanded_stddev\")\n",
    "        \n",
    "        binarized_filters = tf.sign(tiled_filters + expanded_stddev, name=\"binarized_filters\")\n",
    "        return binarized_filters"
   ]
  },
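  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "#### Calculating alphas\n",
    "Now we can calculate the alphas from the *binary filters* and the *convolution filters* by minimizing the *squared difference*\n",
    "$\\|\\mathbf{W}-\\mathbf{B}\\alpha\\|^2$\n",
    "\n",
    "Here $\\mathbf{W}$ is flattened into a vector, the columns of $\\mathbf{B}$ are the flattened binary filters $\\mathbf{B}_i$, and $\\alpha=(\\alpha_1, ..., \\alpha_M)^T$"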
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "def get_alphas(convolution_filters, binary_filters, no_filters, name=None):\n",
    "    with tf.name_scope(name, \"get_alphas\"):\n",
    "        # Reshaping convolution filters to be one dimensional and binary filters to be of [no_filters, -1] dimension\n",
    "        reshaped_convolution_filters = tf.reshape(convolution_filters, [-1], name=\"reshaped_convolution_filters\")\n",
    "        reshaped_binary_filters = tf.reshape(binary_filters, [no_filters, -1],\n",
    "                                             name=\"reshaped_binary_filters\")\n",
    "        \n",
    "        # Creating variable for alphas\n",
    "        alphas = tf.Variable(tf.random_normal(shape=(no_filters, 1), mean=1.0, stddev=0.1), name=\"alphas\")\n",
    "        \n",
    "        # Calculating B*alpha, the weighted sum of the binary filters\n",
    "        weighted_sum_filters = tf.reduce_sum(tf.multiply(alphas, reshaped_binary_filters),\n",
    "                                             axis=0, name=\"weighted_sum_filters\")\n",
    "        \n",
    "        # Defining loss\n",
    "        error = tf.square(reshaped_convolution_filters - weighted_sum_filters, name=\"alphas_error\")\n",
    "        loss = tf.reduce_mean(error, axis=0, name=\"alphas_loss\")\n",
    "        \n",
    "        # Defining optimizer\n",
    "        training_op = tf.train.AdamOptimizer().minimize(loss, var_list=[alphas],\n",
    "                                                        name=\"alphas_training_op\")\n",
    "        \n",
    "        return alphas, training_op, loss"
   ]
  },
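  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "For reference, this least-squares problem also has a closed-form solution, so the iterative optimizer above is one option rather than a necessity. A short derivation, assuming $\\mathbf{w}$ denotes $\\mathbf{W}$ flattened into a vector of length $n$ and $\\mathbf{B}\\in\\{-1,+1\\}^{n\\times M}$ stacks the flattened binary filters as linearly independent columns (notation introduced here for illustration):\n",
    "\n",
    "$\\nabla_{\\alpha}\\|\\mathbf{w}-\\mathbf{B}\\alpha\\|^2 = -2\\mathbf{B}^T(\\mathbf{w}-\\mathbf{B}\\alpha) = 0 \\;\\Rightarrow\\; \\mathbf{B}^T\\mathbf{B}\\alpha = \\mathbf{B}^T\\mathbf{w} \\;\\Rightarrow\\; \\alpha = (\\mathbf{B}^T\\mathbf{B})^{-1}\\mathbf{B}^T\\mathbf{w}$\n",
    "\n",
    "Since $\\mathbf{B}^T\\mathbf{B}$ is only $M\\times M$, solving this system is cheap. `get_alphas` instead approaches the same minimizer with Adam steps on the `alphas` variable, which keeps everything inside the TensorFlow graph."
   ]
  },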
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Creating ApproxConv using the binary filters\n",
    "$\\mathbf{O}=\\sum\\limits_{m=1}^M\\alpha_m\\operatorname{Conv}(\\mathbf{B}_m, \\mathbf{A})$"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "As mentioned in the paper, it is better to first train the network with ordinary convolution layers and then convert the filters into binary filters, while still allowing the original filters to be trained."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "def ApproxConv(no_filters, convolution_filters, convolution_biases=None,\n",
    "               strides=(1, 1), padding=\"VALID\", name=None):\n",
    "    with tf.name_scope(name, \"ApproxConv\"):\n",
    "        # Creating variables from input convolution filters and convolution biases\n",
    "        filters = tf.Variable(convolution_filters, dtype=tf.float32, name=\"filters\")\n",
    "        if convolution_biases is None:\n",
    "            biases = 0.\n",
    "        else:\n",
    "            biases = tf.Variable(convolution_biases, dtype=tf.float32, name=\"biases\")\n",
    "        \n",
    "        # Creating binary filters\n",
    "        binary_filters = get_binary_filters(filters, no_filters)\n",
    "        \n",
    "        # Getting alphas\n",
    "        alphas, alphas_training_op, alphas_loss = get_alphas(filters, binary_filters,\n",
    "                                                             no_filters)\n",
    "        \n",
    "        # Defining function for closure to accept multiple inputs with same filters\n",
    "        def ApproxConvLayer(input_tensor, name=None):\n",
    "            with tf.name_scope(name, \"ApproxConv_Layer\"):\n",
    "                # Reshaping alphas to match the input tensor\n",
    "                reshaped_alphas = tf.reshape(alphas,\n",
    "                                             shape=[no_filters] + [1] * len(input_tensor.get_shape()),\n",
    "                                             name=\"reshaped_alphas\")\n",
    "                \n",
    "                # Calculating convolution for each binary filter\n",
    "                approxConv_outputs = []\n",
    "                for index in range(no_filters):\n",
    "                    # Binary convolution\n",
    "                    this_conv = tf.nn.conv2d(input_tensor, binary_filters[index],\n",
    "                                             strides=(1,) + strides + (1,),\n",
    "                                             padding=padding)\n",
    "                    approxConv_outputs.append(this_conv + biases)\n",
    "                conv_outputs = tf.convert_to_tensor(approxConv_outputs, dtype=tf.float32,\n",
    "                                                    name=\"conv_outputs\")\n",
    "                \n",
    "                # Summing up each of the binary convolution\n",
    "                ApproxConv_output = tf.reduce_sum(tf.multiply(conv_outputs, reshaped_alphas), axis=0)\n",
    "                \n",
    "                return ApproxConv_output\n",
    "        \n",
    "        return alphas_training_op, ApproxConvLayer, alphas_loss"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Multiple binary activations and bitwise convolution\n",
    "With the ApproxConv layers, convolution can be carried out using only summation operations. But the paper suggests something even better: if the input to the convolution layer is also binarized, the summation can be replaced by bitwise operations.\n",
    "To that end, the authors suggest binarizing the input several times, each time with a different shift, which produces multiple binary inputs."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "First, the input is shifted by a parameter $\\nu$, learnable by the network, and clipped to the range $[0, 1]$. Using several values of $\\nu$ gives several shifted inputs  \n",
    "$\\operatorname{h_{\\nu}}(x)=\\operatorname{clip}(x + \\nu, 0, 1)$  \n",
    "  \n",
    "Then it is binarized using the following function  \n",
    "$\\operatorname{H_{\\nu}}(\\mathbf{R})=2\\mathbb{I}_{\\operatorname{h_{\\nu}}(\\mathbf{R})\\geq0.5}-1$\n",
    "\n",
    "The above function can be implemented as  \n",
    "$\\operatorname{H_{\\nu}}(\\mathbf{R})=\\operatorname{sign}(\\operatorname{h_{\\nu}}(\\mathbf{R}) - 0.5)$  \n",
    "which matches the indicator form except exactly at $\\operatorname{h_{\\nu}}(\\mathbf{R})=0.5$, where `tf.sign` returns 0\n",
    "\n",
    "Now, after calculating the **ApproxConv** over each binarized input, their weighted sum can be taken using trainable parameters $\\beta_n$"
   ]
  },
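  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "To make the effect of the shift concrete, here is a small worked example with made-up scalar inputs and shifts (the numbers are illustrative assumptions, not values from the paper):\n",
    "\n",
    "$\\nu = 0,\\ x = 0.3:\\quad \\operatorname{h_{\\nu}}(x) = 0.3 < 0.5 \\;\\Rightarrow\\; \\operatorname{H_{\\nu}}(x) = -1$  \n",
    "$\\nu = 0.3,\\ x = 0.3:\\quad \\operatorname{h_{\\nu}}(x) = 0.6 \\geq 0.5 \\;\\Rightarrow\\; \\operatorname{H_{\\nu}}(x) = +1$  \n",
    "$\\nu = -0.4,\\ x = 0.8:\\quad \\operatorname{h_{\\nu}}(x) = 0.4 < 0.5 \\;\\Rightarrow\\; \\operatorname{H_{\\nu}}(x) = -1$\n",
    "\n",
    "In other words, $\\operatorname{H_{\\nu}}(x) = +1$ exactly when $x \\geq 0.5 - \\nu$, so each shift $\\nu$ moves the binarization threshold, and several shifts together give a coarse multi-level encoding of the same activation."
   ]
  },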
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "def ABC(convolution_filters, convolution_biases=None, no_binary_filters=5, no_ApproxConvLayers=5,\n",
    "        strides=(1, 1), padding=\"VALID\", name=None):\n",
    "    with tf.name_scope(name, \"ABC\"):\n",
    "        # Creating variables shift parameters and weighted sum parameters (betas)\n",
    "        shift_parameters = tf.Variable(tf.constant(0., shape=(no_ApproxConvLayers, 1)), dtype=tf.float32,\n",
    "                                       name=\"shift_parameters\")\n",
    "        betas = tf.Variable(tf.constant(1., shape=(no_ApproxConvLayers, 1)), dtype=tf.float32,\n",
    "                            name=\"betas\")\n",
    "        \n",
    "        # Instantiating the ApproxConv Layer\n",
    "        alphas_training_op, ApproxConvLayer, alphas_loss = ApproxConv(no_binary_filters,\n",
    "                                                                      convolution_filters, convolution_biases,\n",
    "                                                                      strides, padding)\n",
    "        \n",
    "        def ABCLayer(input_tensor, name=None):\n",
    "            with tf.name_scope(name, \"ABCLayer\"):\n",
    "                # Reshaping betas to match the input tensor\n",
    "                reshaped_betas = tf.reshape(betas,\n",
    "                                            shape=[no_ApproxConvLayers] + [1] * len(input_tensor.get_shape()),\n",
    "                                            name=\"reshaped_betas\")\n",
    "                \n",
    "                # Calculating ApproxConv for each shifted input\n",
    "                ApproxConv_layers = []\n",
    "                for index in range(no_ApproxConvLayers):\n",
    "                    # Shifting and binarizing input\n",
    "                    shifted_input = tf.clip_by_value(input_tensor + shift_parameters[index], 0., 1.,\n",
    "                                                     name=\"shifted_input_\" + str(index))\n",
    "                    binarized_activation = tf.sign(shifted_input - 0.5)\n",
    "                    \n",
    "                    # Passing through the ApproxConv layer\n",
    "                    ApproxConv_layers.append(ApproxConvLayer(binarized_activation))\n",
    "                ApproxConv_output = tf.convert_to_tensor(ApproxConv_layers, dtype=tf.float32,\n",
    "                                                         name=\"ApproxConv_output\")\n",
    "                \n",
    "                # Taking the weighted sum using the betas\n",
    "                ABC_output = tf.reduce_sum(tf.multiply(ApproxConv_output, reshaped_betas), axis=0)\n",
    "                return ABC_output\n",
    "        \n",
    "        return alphas_training_op, ABCLayer, alphas_loss"
   ]
  },
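  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Putting the pieces together, the computation built by `ABC` can be summarized in a single expression. This is a sketch that leaves out the bias term and assumes $N$ stands for `no_ApproxConvLayers`, $M$ for `no_binary_filters`, and $\\mathbf{A}_n = \\operatorname{H_{\\nu_n}}(\\mathbf{R})$ for the $n$-th binarized input (symbols introduced here for illustration):\n",
    "\n",
    "$\\mathbf{O}=\\sum\\limits_{n=1}^N\\beta_n\\sum\\limits_{m=1}^M\\alpha_m\\operatorname{Conv}(\\mathbf{B}_m, \\mathbf{A}_n)$\n",
    "\n",
    "With the defaults $M = N = 5$ this amounts to $M \\cdot N = 25$ binary convolutions per layer, which is the nested loop structure seen in `ABCLayer` and `ApproxConvLayer`."
   ]
  },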
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Testing\n",
    "Let's just test our network using MNIST"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Successfully downloaded train-images-idx3-ubyte.gz 9912422 bytes.\n",
      "Extracting /tmp/data/train-images-idx3-ubyte.gz\n",
      "Successfully downloaded train-labels-idx1-ubyte.gz 28881 bytes.\n",
      "Extracting /tmp/data/train-labels-idx1-ubyte.gz\n",
      "Successfully downloaded t10k-images-idx3-ubyte.gz 1648877 bytes.\n",
      "Extracting /tmp/data/t10k-images-idx3-ubyte.gz\n",
      "Successfully downloaded t10k-labels-idx1-ubyte.gz 4542 bytes.\n",
      "Extracting /tmp/data/t10k-labels-idx1-ubyte.gz\n"
     ]
    }
   ],
   "source": [
    "# Importing data\n",
    "from tensorflow.examples.tutorials.mnist import input_data\n",
    "!mkdir -p /tmp/data\n",
    "mnist = input_data.read_data_sets(\"/tmp/data/\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "# Defining utils function\n",
    "def weight_variable(shape, name=\"weight\"):\n",
    "    initial = tf.truncated_normal(shape, stddev=0.1)\n",
    "    return tf.Variable(initial, name=name)\n",
    "\n",
    "def bias_variable(shape, name=\"bias\"):\n",
    "    initial = tf.constant(0.1, shape=shape)\n",
    "    return tf.Variable(initial, name=name)\n",
    "\n",
    "def conv2d(x, W):\n",
    "    return tf.nn.conv2d(x, W, strides=[1, 1, 1, 1], padding='SAME')\n",
    "\n",
    "def max_pool_2x2(x):\n",
    "    return tf.nn.max_pool(x, ksize=[1, 2, 2, 1],\n",
    "                        strides=[1, 2, 2, 1], padding='SAME')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "# Creating the graph\n",
    "without_ABC_graph = tf.Graph()\n",
    "with without_ABC_graph.as_default():\n",
    "    # Defining inputs\n",
    "    x = tf.placeholder(dtype=tf.float32)\n",
    "    x_image = tf.reshape(x, [-1, 28, 28, 1])\n",
    "    \n",
    "    # Convolution Layer 1\n",
    "    W_conv1 = weight_variable(shape=([5, 5, 1, 32]), name=\"W_conv1\")\n",
    "    b_conv1 = bias_variable(shape=[32], name=\"b_conv1\")\n",
    "    conv1 = (conv2d(x_image, W_conv1) + b_conv1)\n",
    "    pool1 = max_pool_2x2(conv1)\n",
    "    bn_conv1 = tf.layers.batch_normalization(pool1, axis=-1, name=\"batchNorm1\")\n",
    "    h_conv1 = tf.nn.relu(bn_conv1)\n",
    "\n",
    "    # Convolution Layer 2\n",
    "    W_conv2 = weight_variable(shape=([5, 5, 32, 64]), name=\"W_conv2\")\n",
    "    b_conv2 = bias_variable(shape=[64], name=\"b_conv2\")\n",
    "    conv2 = (conv2d(h_conv1, W_conv2) + b_conv2)\n",
    "    pool2 = max_pool_2x2(conv2)\n",
    "    bn_conv2 = tf.layers.batch_normalization(pool2, axis=-1, name=\"batchNorm2\")\n",
    "    h_conv2 = tf.nn.relu(bn_conv2)\n",
    "\n",
    "    # Flat the conv2 output\n",
    "    h_conv2_flat = tf.reshape(h_conv2, shape=(-1, 7*7*64))\n",
    "\n",
    "    # Dense layer1\n",
    "    W_fc1 = weight_variable([7 * 7 * 64, 1024])\n",
    "    b_fc1 = bias_variable([1024])\n",
    "    h_fc1 = tf.nn.relu(tf.matmul(h_conv2_flat, W_fc1) + b_fc1)\n",
    "\n",
    "    # Dropout\n",
    "    keep_prob = tf.placeholder(tf.float32)\n",
    "    h_fc1_drop = tf.nn.dropout(h_fc1, keep_prob)\n",
    "\n",
    "    # Output layer\n",
    "    W_fc2 = weight_variable([1024, 10])\n",
    "    b_fc2 = bias_variable([10])\n",
    "\n",
    "    y_conv = tf.matmul(h_fc1_drop, W_fc2) + b_fc2\n",
    "    \n",
    "    # Labels\n",
    "    y = tf.placeholder(tf.int32, [None])\n",
    "    y_ = tf.one_hot(y, 10)\n",
    "    \n",
    "    # Defining optimizer and loss\n",
    "    cross_entropy = tf.reduce_mean(\n",
    "        tf.nn.softmax_cross_entropy_with_logits(labels=y_, logits=y_conv))\n",
    "    train_step = tf.train.AdamOptimizer(1e-4).minimize(cross_entropy)\n",
    "    correct_prediction = tf.equal(tf.argmax(y_conv, 1), tf.argmax(y_, 1))\n",
    "    accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32))\n",
    "    \n",
    "    # Initializer\n",
    "    graph_init = tf.global_variables_initializer()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Let's just define a dictionary to hold the numpy values of the calculated parameters of the network, so that we can feed it directly to our custom network"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "# Defining variables to save. These will be fed to our custom layer\n",
    "variables_to_save = {\"W_conv1\": W_conv1,\n",
    "                     \"b_conv1\": b_conv1,\n",
    "                     \"W_conv2\": W_conv2,\n",
    "                     \"b_conv2\": b_conv2,\n",
    "                     \"W_fc1\": W_fc1,\n",
    "                     \"b_fc1\": b_fc1,\n",
    "                     \"W_fc2\": W_fc2,\n",
    "                     \"b_fc2\": b_fc2}\n",
    "values = {}"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch: 1  Val accuracy: 88.0000%  Loss: 0.432063\n",
      "Epoch: 2  Val accuracy: 98.0000%  Loss: 0.128601\n",
      "Epoch: 3  Val accuracy: 96.0000%  Loss: 0.197146\n",
      "Epoch: 4  Val accuracy: 96.0000%  Loss: 0.111511\n",
      "Epoch: 5  Val accuracy: 92.0000%  Loss: 0.232009\n"
     ]
    }
   ],
   "source": [
    "n_epochs = 5\n",
    "batch_size = 32\n",
    "        \n",
    "with tf.Session(graph=without_ABC_graph) as sess:\n",
    "    sess.run(graph_init)\n",
    "    for epoch in range(n_epochs):\n",
    "        for iteration in range(1, 200 + 1):\n",
    "            batch = mnist.train.next_batch(50)\n",
    "            \n",
    "            # Run operation and calculate loss\n",
    "            _, loss_train = sess.run([train_step, cross_entropy],\n",
    "                                     feed_dict={x: batch[0], y: batch[1], keep_prob: 0.5})\n",
    "            print(\"\\rIteration: {}/{} ({:.1f}%)  Loss: {:.5f}\".format(\n",
    "                      iteration, 200,\n",
    "                      iteration * 100 / 200,\n",
    "                      loss_train),\n",
    "                  end=\"\")\n",
    "\n",
    "        # At the end of each epoch,\n",
    "        # measure the validation loss and accuracy:\n",
    "        loss_vals = []\n",
    "        acc_vals = []\n",
    "        for iteration in range(1, 200 + 1):\n",
    "            X_batch, y_batch = mnist.validation.next_batch(batch_size)\n",
    "            # Evaluate on the validation batch, not on the last training batch\n",
    "            acc_val, loss_val = sess.run([accuracy, cross_entropy],\n",
    "                                         feed_dict={x: X_batch, y: y_batch, keep_prob: 1.0})\n",
    "            loss_vals.append(loss_val)\n",
    "            acc_vals.append(acc_val)\n",
    "            print(\"\\rEvaluating the model: {}/{} ({:.1f}%)\".format(iteration, 200,\n",
    "                iteration * 100 / 200),\n",
    "                  end=\" \" * 10)\n",
    "        loss_val = np.mean(loss_vals)\n",
    "        acc_val = np.mean(acc_vals)\n",
    "        print(\"\\rEpoch: {}  Val accuracy: {:.4f}%  Loss: {:.6f}\".format(\n",
    "            epoch + 1, acc_val * 100, loss_val))\n",
    "        \n",
    "    # On completion of training, save the variables to be fed to custom model\n",
    "    for var_name in variables_to_save:\n",
    "        values[var_name] = sess.run(variables_to_save[var_name])"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Let's build our model now"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 17,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "custom_graph = tf.Graph()\n",
    "with custom_graph.as_default():\n",
    "    alphas_training_operations = []\n",
    "    \n",
    "    # Inputs\n",
    "    x = tf.placeholder(dtype=tf.float32)\n",
    "    x_image = tf.reshape(x, [-1, 28, 28, 1])\n",
    "    \n",
    "    # Convolution Layer 1\n",
    "    W_conv1 = tf.Variable(values[\"W_conv1\"], name=\"W_conv1\")\n",
    "    b_conv1 = tf.Variable(values[\"b_conv1\"], name=\"b_conv1\")\n",
    "    alphas_training_op1, ABCLayer1, alphas_loss1 = ABC(W_conv1, b_conv1,\n",
    "                                                       no_binary_filters=5,\n",
    "                                                       no_ApproxConvLayers=5,\n",
    "                                                       padding=\"SAME\")\n",
    "    alphas_training_operations.append(alphas_training_op1)\n",
    "    conv1 = ABCLayer1(x_image)\n",
    "    pool1 = max_pool_2x2(conv1)\n",
    "    bn_conv1 = tf.layers.batch_normalization(pool1, axis=-1)\n",
    "    h_conv1 = tf.nn.relu(bn_conv1)\n",
    "\n",
    "    # Convolution Layer 2\n",
    "    W_conv2 = tf.Variable(values[\"W_conv2\"], name=\"W_conv2\")\n",
    "    b_conv2 = tf.Variable(values[\"b_conv2\"], name=\"b_conv2\")\n",
    "    alphas_training_op2, ABCLayer2, alphas_loss2 = ABC(W_conv2, b_conv2,\n",
    "                                                       no_binary_filters=5,\n",
    "                                                       no_ApproxConvLayers=5,\n",
    "                                                       padding=\"SAME\")\n",
    "    alphas_training_operations.append(alphas_training_op2)\n",
    "    conv2 = ABCLayer2(h_conv1)\n",
    "    pool2 = max_pool_2x2(conv2)\n",
    "    bn_conv2 = tf.layers.batch_normalization(pool2, axis=-1)\n",
    "    h_conv2 = tf.nn.relu(bn_conv2)\n",
    "\n",
    "    # Flat the conv2 output\n",
    "    h_conv2_flat = tf.reshape(h_conv2, shape=(-1, 7*7*64))\n",
    "\n",
    "    # Dense layer1\n",
    "    W_fc1 = weight_variable([7 * 7 * 64, 1024])\n",
    "    b_fc1 = bias_variable([1024])\n",
    "    h_fc1 = tf.nn.relu(tf.matmul(h_conv2_flat, W_fc1) + b_fc1)\n",
    "\n",
    "    # Dropout\n",
    "    keep_prob = tf.placeholder(tf.float32)\n",
    "    h_fc1_drop = tf.nn.dropout(h_fc1, keep_prob)\n",
    "\n",
    "    # Output layer\n",
    "    W_fc2 = weight_variable([1024, 10])\n",
    "    b_fc2 = bias_variable([10])\n",
    "    y_conv = tf.matmul(h_fc1_drop, W_fc2) + b_fc2\n",
    "    \n",
    "    # Labels\n",
    "    y = tf.placeholder(tf.int32, [None])\n",
    "    y_ = tf.one_hot(y, 10)\n",
    "    \n",
    "    # Defining optimizer and loss\n",
    "    cross_entropy = tf.reduce_mean(\n",
    "        tf.nn.softmax_cross_entropy_with_logits(labels=y_, logits=y_conv))\n",
    "    train_step = tf.train.AdamOptimizer(1e-4).minimize(cross_entropy)\n",
    "    correct_prediction = tf.equal(tf.argmax(y_conv, 1), tf.argmax(y_, 1))\n",
    "    accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32))\n",
    "    \n",
    "    graph_init = tf.global_variables_initializer()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 20,
   "metadata": {
    "scrolled": true
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch: 1  Val accuracy: 88.0000%  Loss: 6.530759\n",
      "Epoch: 2  Val accuracy: 86.0000%  Loss: 4.208882\n",
      "Epoch: 3  Val accuracy: 92.0000%  Loss: 1.455365\n",
      "Epoch: 4  Val accuracy: 92.0000%  Loss: 0.708834\n",
      "Epoch: 5  Val accuracy: 86.0000%  Loss: 0.366106\n"
     ]
    }
   ],
   "source": [
    "n_epochs = 5\n",
    "batch_size = 32\n",
    "alpha_training_epochs = 200\n",
    "        \n",
    "with tf.Session(graph=custom_graph) as sess:\n",
    "    sess.run(graph_init)\n",
    "    for epoch in range(n_epochs):\n",
    "        for iteration in range(1, 200 + 1):\n",
    "            # Training alphas\n",
    "            for alpha_training_op in alphas_training_operations:\n",
    "                for alpha_epoch in range(alpha_training_epochs):\n",
    "                    sess.run(alpha_training_op)\n",
    "            \n",
    "            batch = mnist.train.next_batch(50)\n",
    "            \n",
    "            # Run operation and calculate loss\n",
    "            _, loss_train = sess.run([train_step, cross_entropy],\n",
    "                                     feed_dict={x: batch[0], y: batch[1], keep_prob: 0.5})\n",
    "            print(\"\\rIteration: {}/{} ({:.1f}%)  Loss: {:.5f}\".format(\n",
    "                      iteration, 200,\n",
    "                      iteration * 100 / 200,\n",
    "                      loss_train),\n",
    "                  end=\"\")\n",
    "\n",
    "        # At the end of each epoch,\n",
    "        # measure the validation loss and accuracy:\n",
    "        \n",
    "        # Training alphas\n",
    "        for alpha_training_op in alphas_training_operations:\n",
    "            for alpha_epoch in range(alpha_training_epochs):\n",
    "                sess.run(alpha_training_op)\n",
    "                    \n",
    "        loss_vals = []\n",
    "        acc_vals = []\n",
    "        for iteration in range(1, 200 + 1):\n",
    "            X_batch, y_batch = mnist.validation.next_batch(batch_size)\n",
    "            # Evaluate on the validation batch, not on the last training batch\n",
    "            acc_val, loss_val = sess.run([accuracy, cross_entropy],\n",
    "                                         feed_dict={x: X_batch, y: y_batch, keep_prob: 1.0})\n",
    "            loss_vals.append(loss_val)\n",
    "            acc_vals.append(acc_val)\n",
    "            print(\"\\rEvaluating the model: {}/{} ({:.1f}%)\".format(iteration, 200,\n",
    "                iteration * 100 / 200),\n",
    "                  end=\" \" * 10)\n",
    "        loss_val = np.mean(loss_vals)\n",
    "        acc_val = np.mean(acc_vals)\n",
    "        print(\"\\rEpoch: {}  Val accuracy: {:.4f}%  Loss: {:.6f}\".format(\n",
    "            epoch + 1, acc_val * 100, loss_val))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "tensorflow",
   "language": "python",
   "name": "tensorflow"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 2
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython2",
   "version": "2.7.15"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
